import json
import pandas as pd
from sklearn.metrics import silhouette_score,calinski_harabasz_score,pairwise_distances
from Kernels.src.Analysis.Clustering import *
from sklearn.cluster import SpectralClustering
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import normalized_mutual_info_score, adjusted_mutual_info_score
from sklearn.metrics import adjusted_rand_score
import itertools
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.graphics.mosaicplot import mosaic
import lifelines
import statsmodels.api as sm
import itertools
from scipy import stats
import plotly.graph_objects as go
from matplotlib.offsetbox import AnchoredText
import matplotlib.font_manager as fm
import seaborn as sns
import matplotlib.colorbar as cb
# Absolute path of the Arial TrueType file used for all figures.
fontPath = "/CTGlab/home/danilo/.fonts/arial.ttf"

# Seaborn presets are applied first so that the explicit matplotlib
# settings below take precedence over them.
sns.set_context("paper")
sns.set_style("whitegrid")

# Register the font file with matplotlib; `prop` is used to look up the
# family name that the file declares.
fm.fontManager.addfont(fontPath)
prop = fm.FontProperties(fname=fontPath)

# Global matplotlib style, applied in one go.
plt.rcParams.update({
    # Fonts
    'font.family': prop.get_name(),
    'font.size': 14,
    'axes.labelsize': 14,
    'axes.titlesize': 15,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'legend.labelspacing': 1.5,
    'legend.fontsize': 14,
    # Disable the LaTeX interpreter
    'text.usetex': False,
    # Lines and axes
    'lines.linewidth': 1,
    'axes.linewidth': 1,
    'axes.edgecolor': '#222222',
    'axes.titley': 1.05,  # y is in axes-relative coordinates.
    'axes.titlepad': 0,
})
# Load the UMAP table (first CSV column is the index) and keep a
# reproducible random subset of 100 rows.
data_input = pd.read_csv(
    "Input_data/UMAP_METABRIC_4cmp_IC10_full.csv",
    sep=",",
    index_col=0,
)
data_input = data_input.sample(n=100, axis=0, random_state=42)

# Empty frame sharing the index of the sampled rows.
df_new_clust = pd.DataFrame()
df_new_clust.index = data_input.index

# Feature selection: one 'Component_<k>' column per qubit (k = 1..n_qubits).
n_qubits = 4
features = [f'Component_{k}' for k in range(1, n_qubits + 1)]
labels = 'IntClustMemb'
print(features)

# Label vector (numpy array) and feature frame for the sampled rows.
X_train = data_input[features]
y_train = data_input[labels].to_numpy()
['Component_1', 'Component_2', 'Component_3', 'Component_4']
# Build `quantumDf`: clinical METABRIC metadata joined with the cluster
# assignments of the "2000 samples" run, one row per patient.
#Load original patient index from data and metabric data(metadata)
#Load clustering data
print('#### CASE 2000 samples simulated ####')
df_clust=pd.read_csv('/CTGlab/home/valeria/Quantum-Machine-Learning-for-Expression-Data/Results/Unsupervised_2000_umap/clustering_1980_clusters.csv',index_col=0)
# Normalise naming: the "_Z_full" suffix in column names becomes "_Z", and
# cells whose whole value is 'Z_full' become 'Z'.
df_clust.columns = df_clust.columns.str.replace("_Z_full", "_Z")
df_clust.replace('Z_full','Z',inplace=True)
# Re-read the full table; this overwrites the 100-row sample taken earlier.
data_input = pd.read_csv("Input_data/UMAP_METABRIC_4cmp_IC10_full.csv", sep = ",",index_col=0)
patient_index=data_input.index
#set index to df_clust to patient index
# NOTE(review): positional assignment - assumes df_clust has the same number
# of rows, in the same order, as data_input. TODO confirm.
df_clust.index=patient_index
#Load metadata
# The first four lines of the clinical file are skipped before the header.
metabric = pd.read_csv('/CTGlab/data/brca_metabric/data_clinical_patient.txt', sep='\t', skiprows=[0,1,2,3])
#Remove patients with missing OS_STATUS and merge with data
metabricSubset = metabric.loc[~metabric['OS_STATUS'].isna()]
# Left join on PATIENT_ID against the df_clust index, then drop every patient
# that has a missing value in any cluster column (i.e. unmatched patients).
quantumDf = pd.merge(metabricSubset, df_clust, how='left', right_on=[df_clust.index], left_on=['PATIENT_ID']).dropna(subset=df_clust.columns)
# 1 only when VITAL_STATUS is exactly "Died of Disease", 0 otherwise.
quantumDf['OS_STATUS_censored'] = quantumDf['VITAL_STATUS'].apply(lambda x: 1 if x == "Died of Disease" else 0)
# Keep the integer before the first ':' of OS_STATUS
# (presumably values look like "0:LIVING" / "1:DECEASED" - verify against the file).
quantumDf['OS_STATUS'] = quantumDf['OS_STATUS'].apply(lambda x: int(x.split(':')[0]))
#### CASE 2000 samples simulated ####
# Names of all columns that hold a clustering result (contain 'Cluster').
cluster_col = [col_name for col_name in quantumDf.columns if 'Cluster' in col_name]

# Work on a copy in which every categorical value is rewritten as
# "<column>_<value>", so that values are unique across columns.
quantumDf_s = quantumDf.copy()
columns_to_prefix = quantumDf_s.loc[:, ['INTCLUST', 'CLAUDIN_SUBTYPE'] + cluster_col].columns
for column in columns_to_prefix:
    quantumDf_s[column] = quantumDf_s[column].apply(
        lambda cell, prefix=column: f"{prefix}_{cell}"
    )
def _sankey_node_label(value, column):
    """Return the label shown for a node: the value without its '<column>_' prefix."""
    text = str(value)
    prefix = f"{column}_"
    if text.startswith(prefix):
        return text[len(prefix):]
    # Value was not prefixed with its column name: keep the historical
    # behaviour of showing only the part after the last underscore.
    return text.split('_')[-1]


def _sankey_highlight_values(valueToHighlight):
    """Normalise the highlight argument (None, scalar or list-like) to a tuple."""
    if valueToHighlight is None:
        return ()
    # is_list_like is False for plain strings, so 'High' stays a single value.
    if pd.api.types.is_list_like(valueToHighlight):
        return tuple(valueToHighlight)
    return (valueToHighlight,)


def plotSankey(df: pd.DataFrame, cols: list, names=None,linkVar=None, valueToHighlight=None):
    """
    Plot a Sankey diagram based on the given dataframe and columns.

    Each distinct value of each column in `cols` becomes a node, and links
    are drawn between the nodes of consecutive columns, sized by the number
    of rows sharing both values. Values must be unique across columns
    (e.g. prefixed as "<column>_<value>"); a value shared by two columns is
    mapped to a single node.

    Parameters:
    - df (pd.DataFrame): The dataframe containing the data.
    - cols (list): Column names to plot, in left-to-right order.
    - names (list): Header text written above each column (optional,
      defaults to `cols`).
    - linkVar (str): The name of the column to use as a link variable
      (optional). When given, every link is split by the values of this
      column.
    - valueToHighlight: The value in the link variable to highlight, or a
      list-like of such values (optional). Matching links are drawn in
      orange, all others in light grey.

    Returns:
    - fig: The generated plotly figure object (it is also shown).
    """
    highlighted = _sankey_highlight_values(valueToHighlight)

    # Distinct values per column, in order of first appearance so that the
    # node order is reproducible between runs.
    nodesByColumn = [(col, list(dict.fromkeys(df[col]))) for col in cols]

    toenumerate = []
    nodeLabels = []
    for col, colValues in nodesByColumn:
        toenumerate.extend(colValues)
        nodeLabels.extend(_sankey_node_label(value, col) for value in colValues)
    # Node value -> node index expected by plotly.
    mask = {value: position for position, value in enumerate(toenumerate)}

    linkVarValues = df[linkVar].unique() if linkVar is not None else None

    sources = []
    targets = []
    values = []
    colors = []
    linkLabels = []
    # Walk over consecutive column pairs; the column names are known here,
    # so they never have to be parsed back out of the values.
    for (column1, values1), (column2, values2) in zip(nodesByColumn[:-1], nodesByColumn[1:]):
        for value1, value2 in itertools.product(values1, values2):
            source = mask[value1]
            target = mask[value2]
            pairRows = (df[column1] == value1) & (df[column2] == value2)
            if linkVar is not None:
                for val in linkVarValues:
                    flow = int((pairRows & (df[linkVar] == val)).sum())
                    colors.append("orange" if val in highlighted else "lightgrey")
                    sources.append(source)
                    targets.append(target)
                    values.append(flow)
                    linkLabels.append(str(val))
            else:
                colors.append('lightgrey')
                sources.append(source)
                targets.append(target)
                values.append(int(pairRows.sum()))
                linkLabels.append('')

    fig = go.Figure(data=[go.Sankey(
        node = dict(
          pad = 15,
          thickness = 20,
          line = dict(color = "black", width = 0.5),
          label = nodeLabels
        ),
        link = dict(
          source = sources,
          target = targets,
          value = values,
          color = colors,
          # One label per link (the link-variable value it represents).
          label = linkLabels
      ))])

    if names is None:
        names = cols
    # Column headers, placed just above the plotting area.
    for x_coordinate, column_name in enumerate(names):
        fig.add_annotation(
            x=x_coordinate,
            y=1.05,
            xref="x",
            yref="paper",
            text=column_name,
            showarrow=False,
            font=dict(
                size=16,
                color="black"
            ),
            align="center",
        )

    if valueToHighlight is None:
        title = 'Sample Assignments in Clustering Methods'
    else:
        shown = ', '.join(str(value) for value in highlighted)
        title = f"Sample Assignments in Clustering Methods\nHighlighted {shown}"

    fig.update_layout(
        title=title,
        plot_bgcolor='white',
        paper_bgcolor='white',
        font_size=14,
        hovermode='x',
        autosize=False,
        width=1000, height=600,
        xaxis={
            'showgrid': False,  # thin lines in the background
            'zeroline': False,  # thick line at x=0
            'visible': False,   # numbers below
        },
        yaxis={
            'showgrid': False,  # thin lines in the background
            'zeroline': False,  # thick line at x=0
            'visible': False,   # numbers below
        }
    )
    fig.show()
    return fig
Let's check the highest-scoring clusters for the simulated 2000-sample case:
Get meta data columns to plot
# Show the names of the first 24 columns of quantumDf_s
# (bare expression: only displayed in an interactive session).
quantumDf_s.columns[:24]
Index(['PATIENT_ID', 'LYMPH_NODES_EXAMINED_POSITIVE', 'NPI', 'CELLULARITY',
'CHEMOTHERAPY', 'COHORT', 'ER_IHC', 'HER2_SNP6', 'HORMONE_THERAPY',
'INFERRED_MENOPAUSAL_STATE', 'SEX', 'INTCLUST', 'AGE_AT_DIAGNOSIS',
'OS_MONTHS', 'OS_STATUS', 'CLAUDIN_SUBTYPE', 'THREEGENE',
'VITAL_STATUS', 'LATERALITY', 'RADIO_THERAPY', 'HISTOLOGICAL_SUBTYPE',
'BREAST_SURGERY', 'RFS_STATUS', 'RFS_MONTHS'],
dtype='object')
# Clinical annotation columns compared against the clusterings.
meta_plot = [
    'INTCLUST',
    'CLAUDIN_SUBTYPE',
    'ER_IHC',
    'HER2_SNP6',
    'HORMONE_THERAPY',
    'THREEGENE',
    'HISTOLOGICAL_SUBTYPE',
]

# Rebuild the working copy, rewriting every annotation and cluster value as
# "<column>_<value>" so that values are unique across columns.
quantumDf_s = quantumDf.copy()
columns_to_prefix = quantumDf_s.loc[:, meta_plot + cluster_col].columns
for column in columns_to_prefix:
    quantumDf_s[column] = quantumDf_s[column].apply(
        lambda cell, prefix=column: f"{prefix}_{cell}"
    )
Filter configurations
# Load the per-configuration clustering scores.
df_2000 = pd.read_csv(
    '/CTGlab/home/valeria/Quantum-Machine-Learning-for-Expression-Data/Results/Unsupervised_2000_umap/clustering_1980_opt_k_reviewed.csv',
    index_col=0,
)
# Cells whose whole value is 'Z_full' are renamed to 'Z'.
df_2000 = df_2000.replace('Z_full', 'Z')

# rbf rows are repeated once per bandwidth: keep every non-rbf row, and of
# the rbf rows only those with Bandwidth == 1.
is_rbf = df_2000['ftmap'] == 'rbf'
unit_bandwidth = df_2000['Bandwidth'] == 1
df_2000 = df_2000[~is_rbf | unit_bandwidth]
Let's start with the silhouette score.
# Rank every configuration by silhouette score, best first.
df_2000 = df_2000.sort_values(by='silhouette', ascending=False)

# Preview the 10 best configurations among those with silhouette > 0.3
# (bare expression: only displayed in an interactive session).
above_threshold = df_2000[df_2000['silhouette'] > 0.3]
above_threshold.sort_values(by='silhouette', ascending=False).head(10)
| ftmap | K | Bandwidth | s | geom_distance | concentration | silhouette | Score_cluster | CHI | DI | v_intra | v_inter | N_samples | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 135 | ZZ_linear | 2 | 0.125 | 198446.008917 | 11.510186 | 0.066276 | 0.627849 | 0.624853 | 271.129995 | 0.833597 | 0.052648 | 0.000630 | 1980 |
| 28 | rbf | 3 | 1.000 | 126427.404049 | 2.459801 | 0.028149 | 0.568687 | 0.612915 | 1002.217980 | 0.000669 | 0.008163 | 0.030924 | 1980 |
| 27 | rbf | 2 | 1.000 | 126427.404049 | 2.459801 | 0.028149 | 0.541753 | 0.548029 | 1323.383994 | 0.000420 | 0.017752 | 0.024397 | 1980 |
| 136 | ZZ_linear | 3 | 0.125 | 198446.008917 | 11.510186 | 0.066276 | 0.525904 | 0.561904 | 774.054846 | 0.043874 | 0.022971 | 0.055678 | 1980 |
| 29 | rbf | 4 | 1.000 | 126427.404049 | 2.459801 | 0.028149 | 0.427324 | 0.615735 | 1054.539243 | 0.000872 | 0.005637 | 0.030072 | 1980 |
| 30 | rbf | 5 | 1.000 | 126427.404049 | 2.459801 | 0.028149 | 0.414847 | 0.655101 | 1030.559094 | 0.000728 | 0.004759 | 0.028514 | 1980 |
| 31 | rbf | 6 | 1.000 | 126427.404049 | 2.459801 | 0.028149 | 0.408590 | 0.687778 | 949.856227 | 0.000577 | 0.004185 | 0.027981 | 1980 |
| 32 | rbf | 7 | 1.000 | 126427.404049 | 2.459801 | 0.028149 | 0.390088 | 0.717164 | 961.774043 | 0.000453 | 0.003064 | 0.027807 | 1980 |
| 152 | ZZ_linear | 10 | 0.250 | 146583.001305 | 5.823200 | 0.080595 | 0.389346 | 0.600019 | 649.065192 | 0.036875 | 0.023439 | 0.050366 | 1980 |
| 151 | ZZ_linear | 9 | 0.250 | 146583.001305 | 5.823200 | 0.080595 | 0.373357 | 0.562661 | 693.376521 | 0.035284 | 0.030125 | 0.049718 | 1980 |
Considering that most of these are rbf, and thus not very useful for us, we can proceed and plot only the rbf ones to check the consistency between micro- and macro-clusters, and then select one as a reference to compare with the quantum clusterings.
# First 10 rbf configurations with silhouette > 0.3, in the current row
# order of df_2000 (relies on df_2000 having been sorted by silhouette).
best_rbf = df_2000[(df_2000['silhouette'] > 0.3) & (df_2000.ftmap == 'rbf')].iloc[:10]

# Cluster column names follow 'Cluster_<ftmap>_<K>_<Bandwidth>'.
cols_to_plot_rbf = [
    'Cluster_' + row.ftmap + '_' + str(row.K) + '_' + f'{row.Bandwidth:g}'
    for row in best_rbf.itertuples()
]
# Column headers show the silhouette score of each configuration.
cols_name_rbf = ['S:{:.2f}'.format(row.silhouette) for row in best_rbf.itertuples()]

fig = plotSankey(
    quantumDf_s,
    cols=cols_to_plot_rbf,
    names=cols_name_rbf,
    linkVar='CELLULARITY',
    valueToHighlight=['High'],
)
We can see that overall there are two macro-clusters (plus one small additional cluster): one of them tends to stratify further (cluster 0.0), while the other seems to be more consistent. We take K=7 as the reference.
# Reference rbf clustering column; the name follows the
# 'Cluster_<ftmap>_<K>_<Bandwidth>' pattern (rbf, K=7, bandwidth 1).
rbf_clusters_s='Cluster_rbf_7_1'
Let's move on and take the best 5 quantum clusters.
# Show the first 5 non-rbf configurations with silhouette > 0.3, in the
# current row order of df_2000 (bare expression: only displayed interactively).
df_2000[(df_2000['silhouette']>0.3) & (df_2000.ftmap!='rbf')].iloc[:5]
| ftmap | K | Bandwidth | s | geom_distance | concentration | silhouette | Score_cluster | CHI | DI | v_intra | v_inter | N_samples | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 135 | ZZ_linear | 2 | 0.125 | 198446.008917 | 11.510186 | 0.066276 | 0.627849 | 0.624853 | 271.129995 | 0.833597 | 0.052648 | 0.000630 | 1980 |
| 136 | ZZ_linear | 3 | 0.125 | 198446.008917 | 11.510186 | 0.066276 | 0.525904 | 0.561904 | 774.054846 | 0.043874 | 0.022971 | 0.055678 | 1980 |
| 152 | ZZ_linear | 10 | 0.250 | 146583.001305 | 5.823200 | 0.080595 | 0.389346 | 0.600019 | 649.065192 | 0.036875 | 0.023439 | 0.050366 | 1980 |
| 151 | ZZ_linear | 9 | 0.250 | 146583.001305 | 5.823200 | 0.080595 | 0.373357 | 0.562661 | 693.376521 | 0.035284 | 0.030125 | 0.049718 | 1980 |
| 111 | Z | 5 | 0.500 | 139813.611325 | 2.923790 | 0.055403 | 0.365701 | 0.502482 | 977.285005 | 0.044605 | 0.018611 | 0.041749 | 1980 |
Since most of the best clusters (four of the top five) were obtained with the ZZ_linear feature map, we omit the feature-map name from the labels of the following plots to improve readability.
# First 5 non-rbf (quantum) configurations with silhouette > 0.3, in the
# current row order of df_2000 (relies on df_2000 having been sorted by silhouette).
best_quantum = df_2000[(df_2000['silhouette'] > 0.3) & (df_2000.ftmap != 'rbf')].iloc[:5]

# Cluster column names follow 'Cluster_<ftmap>_<K>_<Bandwidth>'.
cols_to_plot_q = [
    'Cluster_' + row.ftmap + '_' + str(row.K) + '_' + f'{row.Bandwidth:g}'
    for row in best_quantum.itertuples()
]
# Column headers show bandwidth and silhouette score of each configuration.
cols_name_q = [
    'b:{},S:{:.3f}'.format(row.Bandwidth, row.silhouette)
    for row in best_quantum.itertuples()
]

fig = plotSankey(
    quantumDf_s,
    cols=cols_to_plot_q,
    names=cols_name_q,
    linkVar='CELLULARITY',
    valueToHighlight=['High'],
)
As in the other cases, we see once again that the stratification with lower K represents a macrostructure in our data and tends to have a higher silhouette value, due to the bias of the silhouette score. There also seems to be good consistency between the finer stratifications (K=10, 9, 8), and considering that they come from the same kernel (ftmap=ZZ_linear, b=0.25), we decide to keep K=10 as the reference since it yields the best silhouette.
# Reference quantum clustering column; the name follows the
# 'Cluster_<ftmap>_<K>_<Bandwidth>' pattern (ZZ_linear, K=10, bandwidth 0.25).
q_cluster_s='Cluster_ZZ_linear_10_0.25'
Now we can put together the quantum, rbf and known stratifications (INTCLUST and CLAUDIN_SUBTYPE) and look at how they compare:
# Echo the name of the reference quantum clustering column.
print(q_cluster_s)
Cluster_ZZ_linear_10_0.25
# Sankey of the reference rbf clustering against the reference quantum one.
reference_cols = [rbf_clusters_s, q_cluster_s]
fig = plotSankey(
    quantumDf_s,
    cols=reference_cols,
    names=['rbf', 'ZZ_linear_0.25 best_s'],
    linkVar='CELLULARITY',
    valueToHighlight=['High'],
)

# Adjusted mutual information between the two reference clusterings.
ami_rbf_vs_quantum = adjusted_mutual_info_score(
    quantumDf_s[rbf_clusters_s],
    quantumDf_s[q_cluster_s],
)
print('AMI between rbf and ZZ_linear_0.25', ami_rbf_vs_quantum)
AMI between rbf and ZZ_linear_0.25 0.5680031848012492
# Adjusted mutual information of every clinical annotation against the
# reference quantum clustering.
for meta in meta_plot:
    ami_value = adjusted_mutual_info_score(quantumDf_s[meta], quantumDf_s[q_cluster_s])
    print('AMI between {} and ZZ_linear_0.25'.format(meta), ami_value)
AMI between INTCLUST and ZZ_linear_0.25 -0.00026901173333492847 AMI between CLAUDIN_SUBTYPE and ZZ_linear_0.25 -9.660250617278738e-05 AMI between ER_IHC and ZZ_linear_0.25 -0.0019880083791308876 AMI between HER2_SNP6 and ZZ_linear_0.25 0.000838793203366204 AMI between HORMONE_THERAPY and ZZ_linear_0.25 -0.0006368616323537453 AMI between THREEGENE and ZZ_linear_0.25 2.5483342799545463e-05 AMI between HISTOLOGICAL_SUBTYPE and ZZ_linear_0.25 0.0006967353256035525
# For every clinical annotation: print its adjusted mutual information with
# the reference quantum clustering, then draw the two side by side.
for meta in meta_plot:
    ami_value = adjusted_mutual_info_score(quantumDf_s[meta], quantumDf_s[q_cluster_s])
    print('AMI between {} and ZZ_linear_0.25'.format(meta), ami_value)
    fig = plotSankey(
        quantumDf_s,
        cols=[q_cluster_s, meta],
        linkVar='CELLULARITY',
        valueToHighlight=['High'],
    )
AMI between INTCLUST and ZZ_linear_0.25 -0.00026901173333492847
AMI between CLAUDIN_SUBTYPE and ZZ_linear_0.25 -9.660250617278738e-05
AMI between ER_IHC and ZZ_linear_0.25 -0.0019880083791308876
AMI between HER2_SNP6 and ZZ_linear_0.25 0.000838793203366204
AMI between HORMONE_THERAPY and ZZ_linear_0.25 -0.0006368616323537453
AMI between THREEGENE and ZZ_linear_0.25 2.5483342799545463e-05
AMI between HISTOLOGICAL_SUBTYPE and ZZ_linear_0.25 0.0006967353256035525